---
title: CNN Interpreter
keywords: fastai
sidebar: home_sidebar
summary: "Wrapper around several model interpretability techniques"
description: "Wrapper around several model interpretability techniques"
---
from fast_impl.visualize import *
path = untar_data(URLs.IMAGEWOOF_320)

# Map each Imagewoof WNID folder name to a human-readable breed label.
lbl_dict = {
    'n02086240': 'Shih-Tzu',
    'n02087394': 'Rhodesian ridgeback',
    'n02088364': 'Beagle',
    'n02089973': 'English foxhound',
    'n02093754': 'Australian terrier',
    'n02096294': 'Border terrier',
    'n02099601': 'Golden retriever',
    'n02105641': 'Old English sheepdog',
    'n02111889': 'Samoyed',
    'n02115641': 'Dingo',
}

# Dataset layout is <train|val>/<wnid>/<image>, so the class comes from the
# parent folder (mapped through lbl_dict) and the split from the grandparent.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=GrandparentSplitter(valid_name='val'),
    get_y=Pipeline([parent_label, lbl_dict.__getitem__]),
    item_tfms=Resize(320),
    batch_tfms=[*aug_transforms(size=224), Normalize.from_stats(*imagenet_stats)],
)
dls = dblock.dataloaders(path, bs=32)
def get_cam_resnet(arch, num_classes: int, pretrained: bool = True):
    """Build a CAM-friendly classifier from a resnet architecture.

    The body is the conv trunk of `arch` (cut before its avg-pool/fc pair),
    and the head is a plain global-average-pool -> Flatten -> Linear stack,
    the structure class-activation mapping requires.

    Parameters
    ----------
    arch : callable
        Architecture constructor (e.g. `resnet34`) accepted by `create_body`.
    num_classes : int
        Number of output classes for the final linear layer.
    pretrained : bool
        Whether to load pretrained weights for the body.

    Returns
    -------
    nn.Sequential of (body, head).
    """
    body = create_body(arch, cut=-2, pretrained=pretrained)  # drop avgpool + fc
    nf = num_features_model(body)  # channels coming out of the conv trunk
    head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(nf, num_classes))
    return nn.Sequential(body, head)
# Wrap the CAM-ready resnet in a Learner and restore previously trained weights.
model = get_cam_resnet(resnet34, dls.c, pretrained=False)
learn = Learner(
    dls,
    model,
    model_dir='/content/models',
    opt_func=ranger,
    metrics=error_rate,
)
learn.load('cam-resnet34')

# Grab one batch and preview the images with their labels.
xb, yb = dls.one_batch()
dls.show_batch((xb, yb))
# Class-activation maps for the whole batch, plus the model's predictions.
cam, y_preds = generate_cam(learn.model, xb, with_preds=True)

# Visualise sample 8 three ways: default merged overlay, un-merged heatmap,
# and the activation map for an explicitly chosen class index.
for extra in ({}, {'merge': False}, {'for_cls': 3}):
    show_at(learn.model, dls.valid, xb, yb, 8, **extra)
# Fresh batch for the batch-level CAM visualisation.
xb,yb = dls.one_batch()
cam_batch,y_preds = generate_cam(learn.model,xb,with_preds=True)
show_cam_batch(xb,yb,cam_batch,y_preds)
# Same pipeline through the CamInterpreter convenience wrapper:
# it pulls a batch from the learner itself and renders it in one call.
interp = CamInterpreter.from_learner(learn)
cam_b,y_preds = interp.generate(with_preds=True)
interp.show_batch()
# Grad-CAM on a standard cnn_learner ------------------------------------------
learn = cnn_learner(dls, resnet34, pretrained=False)
learn.load('resnet34')
m = learn.model.eval()

# Shuffle the validation loader so repeated batches vary.
# (The original set this twice, on both sides of one_batch(); once, before
# drawing the batch, is sufficient.)
dls.valid.shuffle = True
xb, yb = dls.one_batch()
dls.show_batch((xb, yb))

# Grad-CAM for a single image, against its true label.
idx = 7
gcam, preds = generate_gradcam(m, xb[idx], yb[idx], with_preds=True)
x_dec, y_dec = dls.decode_batch((xb[idx][None], yb[idx][None]))[0]
pred_cls = dls.vocab[preds.argmax().item()]
cam_img = CamImage(x_dec, y_dec, pred_cls, gcam[0])
cam_img.show()

# Grad-CAM for an arbitrary class index, not just the predicted/true one.
for_cls = 1
lbl = dls.vocab[for_cls]
print(f"For Class: {lbl}")
cust_gcam = generate_gradcam(m, xb[idx], tensor(for_cls))
cam_img2 = cam_img.new(lbl, cust_gcam)
cam_img2.show()
# Run Grad-CAM image-by-image over the batch, then stack the per-image maps
# and predictions into batch tensors for the grid display.
results = [generate_gradcam(m, x, y, with_preds=True) for x, y in zip(xb, yb)]
gcam_batch = torch.stack([g[0] for g, _ in results])
y_preds = torch.stack([p[0] for _, p in results])
show_gradcam_batch(xb, yb, gcam_batch, y_preds, dls.vocab)
# Same Grad-CAM pipeline via the GradCamInterpreter wrapper.
interp = GradCamInterpreter.from_learner(learn)
interp.show_batch()
# Guided variant for the whole batch, then for a single item (index 0),
# rendered with a grey colormap.
interp.show_batch(guided=True,cmap='Greys')
interp.show_at(0,guided=True,cmap='Greys')